#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright © 2005 Progiciels Bourbeau-Pinard inc.
# François Pinard <pinard@iro.umontreal.ca>, 2005.

"""\
Execute a validation suite built in pylib's py.test style.

Usage: pytest [OPTION]... [PATH]...

Options:
  -h   Print this help and exit right away.
  -v   Produce one line per test, instead of a dot per test.
  -n   Do not capture stdout nor stderr.
  -p   Profile validation suite (through "lsprof").

  -f PREFIX     Use PREFIX instead of "test_" for file names.
  -s FILE       Save ordinals of failed tests, one per line.
  -o ORDINALS   Only consider tests having these execution ordinals.
  -k PATTERN    Only retain tests which names match PATTERN.
  -x PATTERN    Exclude tests which names match PATTERN.
  -l LIMIT      Stop the validation suite after LIMIT errors.

If -l is not used, the validation will stop after 10 errors.  ORDINALS is
a sequence of comma-separated integers.  Options -k, -o and -x may be
repeated; then a test should match at least one of -k options if any,
one of -o options if any, and none of -x options.

If PATH names a file, the name should match "test_.*\.py".  If PATH names
a directory, it is recursively searched to find so matching file names.
If no PATH are given, the current directory is implied.

Test progression is first displayed on standard error.  Then, unless
-s is selected, failed tests are detailed on stdout and, if there is at
least one such failed test, the return status of this program is non-zero.
"""

# This tool implements a minimal set of specifications stolen from the
# excellent Codespeak's py.test, at a time I really needed py.test to be
# more Unicode-aware.

# Jython misses the yield keyword unless it is requested this way.  A
# __future__ import is only legal as the first statement of a module, so
# it cannot be nested in the Python 2 branch below; it is harmless on
# interpreters which always have generators.
from __future__ import generators

import sys

# True when running under Python 3 or later.  The rest of this file tests
# this flag to select between Python 2 and Python 3 spellings.
PYTHON3 = sys.version_info[0] >= 3

if PYTHON3:
    import inspect, io, locale, os, time, traceback
else:
    try:
        import locale
    except ImportError:
        # Jython misses this module.
        locale = None

    try:
        sorted
    except NameError:
        # Jython misses this function.

        def sorted(sequence):
            # Return a new sorted list holding the items of SEQUENCE.
            sequence = list(sequence)
            sequence.sort()
            return sequence

    # Make all classes of this module new-style under Python 2.
    __metaclass__ = type
    import inspect, os, time, traceback
    from StringIO import StringIO

# How many displayable characters in an output line.
WIDTH = 79

class Limit_Reached(Exception):
    """Raised to stop the run once the allowed count of failures is met."""

class Main:
    """Test runner: decodes options, collects tests, runs and reports them."""

    # Defaults for the command-line options decoded by main().
    # -f: test file names should start with this prefix.
    prefix = 'test_'
    # -k: patterns, a test is retained when its name matches one of them.
    pattern = []
    # -x: patterns, a test is skipped when its name matches one of them.
    exclusion = []
    # -o: execution ordinals of the only tests to consider.
    ordinals = []
    # -v: write one line per test instead of one character per test.
    verbose = False
    # -p: profile the tests through "lsprof".
    profile = False
    # -l: count of failures after which the run is stopped.
    limit = 10
    # -n resets this, so stdout and stderr of tests are not captured.
    capture = True
    # -s: name of the file receiving the ordinals of failed tests.
    save = None

    # For handling setup/teardown laziness.  Each delayed setup is either
    # None or a (function, argument) pair which launch_test() calls right
    # before the first test it really executes.  The did_tests_* flags are
    # set by launch_test() and tell if the matching teardown should run.
    delayed_setup_module = None
    delayed_setup_class = None
    did_tests_in_module = False
    did_tests_in_class = False

    def main(self, *arguments):
        if sys.getdefaultencoding() == 'ascii':
            sys.stdout = Friendly_StreamWriter(sys.stdout)
            sys.stderr = Friendly_StreamWriter(sys.stderr)
        import getopt
        options, arguments = getopt.getopt(arguments, 'f:hk:l:no:ps:vx:')
        for option, value in options:
            if option == '-f':
                self.prefix = value
            elif option == '-h':
                sys.stdout.write(__doc__)
                return
            elif option == '-k':
                self.pattern.append(value)
            elif option == '-l':
                self.limit = int(value)
            elif option == '-n':
                self.capture = False
            elif option == '-o':
                self.ordinals += [
                    int(text) for text in value.replace(',', ' ').split()]
            elif option == '-p':
                self.profile = True
            elif option == '-s':
                self.save = value
            elif option == '-v':
                self.verbose = True
            elif option == '-x':
                self.exclusion.append(value)
        if not arguments:
            arguments = '.',
        if self.pattern:
            import re
            self.pattern = re.compile('|'.join(self.pattern))
        else:
            self.pattern = None
        if self.exclusion:
            import re
            self.exclusion = re.compile('|'.join(self.exclusion))
        else:
            self.exclusion = None
        write = sys.stderr.write
        self.failures = []
        self.total_count = 0
        self.skip_count = 0
        start_time = time.time()
        if self.profile:
            try:
                import lsprof
            except ImportError:
                write("WARNING: profiler unavailable.\n")
                self.profiler = None
            else:
                self.profiler = lsprof.Profiler()
        else:
            self.profiler = None
        try:
            try:
                for argument in arguments:
                    for file_name in self.each_file(argument):
                        self.identifier = file_name
                        self.column = 0
                        self.counter = 0
                        directory, base = os.path.split(file_name)
                        sys.path.insert(0, directory)
                        try:
                            try:
                                module = __import__(base[:-3])
                            except ImportError:
                                if self.save:
                                    self.failures.append(self.total_count + 1)
                                else:
                                    if PYTHON3:
                                        tracing = io.StringIO()
                                    else:
                                        tracing = StringIO()
                                    traceback.print_exc(file=tracing)
                                    self.failures.append(
                                        (self.total_count + 1, file_name,
                                         None, None,
                                         None, None,
                                         str(tracing.getvalue())))
                            else:
                                self.handle_module(file_name, module)
                        finally:
                            del sys.path[0]
                        if self.counter and not self.verbose:
                            text = '(%d)' % self.counter
                            if self.column + 1 + len(text) >= WIDTH:
                                write('\n%5d ' % self.counter)
                            else:
                                text = ' ' + text
                            write(text + '\n')
            except KeyboardInterrupt:
                if not self.verbose:
                    write('\n')
                write('\n*** INTERRUPTION! ***\n')
            except Limit_Reached:
                if not self.verbose:
                    write('\n')
                if not self.save:
                    if len(self.failures) == 1:
                        write('\n*** ONE ERROR ALREADY! ***\n')
                    else:
                        write('\n*** %d ERRORS ALREADY! ***\n' % self.limit)
        finally:
            if self.profiler is not None:
                stats = lsprof.Stats(self.profiler.getstats())
                stats.sort('inlinetime')
                write('\n')
                stats.pprint(15)
            if self.failures:
                if len(self.failures) == 1:
                    text = "one FAILED test"
                else:
                    text = "%d FAILED tests" % len(self.failures)
                first = False
            else:
                text = ''
                first = True
            good_count = (self.total_count - self.skip_count
                          - len(self.failures))
            if good_count:
                if first:
                    if good_count == 1:
                        text += "one good test"
                    else:
                        text += "%d good tests" % good_count
                    first = False
                else:
                    if good_count == 1:
                        text += ", one good"
                    else:
                        text += ", %d good" % good_count
            if self.skip_count:
                if first:
                    if self.skip_count == 1:
                        text += "one skipped test"
                    else:
                        text += "%d skipped tests" % self.skip_count
                    first = False
                else:
                    if self.skip_count == 1:
                        text += ", one skipped"
                    else:
                        text += ", %d skipped" % self.skip_count
            if first:
                text = "No test"
            summary = ("\nSummary: %s in %.2f seconds.\n"
                       % (text, time.time() - start_time))
            write(summary)
        if self.save:
            handle = file(self.save, 'w')
            for ordinal in self.failures:
                handle.write('%d\n' % ordinal)
            handle.close()
        else:
            write = sys.stdout.write
            for (ordinal, prefix, function, arguments, stdout, stderr,
                 tracing) in self.failures:
                write('\n' + '=' * WIDTH + '\n')
                write('%d. %s\n' % (ordinal, prefix))
                if function and function.__name__ != os.path.basename(prefix):
                    write("    Fonction %s\n" % function.__name__)
                if arguments:
                    for counter, argument in enumerate(arguments):
                        write("    Arg %d = %r\n" % (counter + 1, argument))
                for buffer, titre in ((stdout, 'STDOUT'), (stderr, 'STDERR')):
                    if buffer:
                        write('\n' + titre + ' >>>\n')
                        write(buffer)
                        if not buffer.endswith('\n'):
                            write('\n')
                write('-' * WIDTH + '\n')
                write(tracing)
            if self.failures:
                write(summary)
                sys.exit(1)

    def each_file(self, path):
        if os.path.isdir(path):
            stack = [path]
            while stack:
                directory = stack.pop(0)
                for base in sorted(os.listdir(directory)):
                    file_name = os.path.join(directory, base)
                    if os.path.isdir(file_name):
                        stack.append(file_name)
                    elif base.startswith(self.prefix) and base.endswith('.py'):
                        yield file_name
        else:
            directory, base = os.path.split(path)
            if base.startswith(self.prefix) and base.endswith('.py'):
                yield path

    def handle_module(self, prefix, module):
        collection = []
        for name, objet in inspect.getmembers(module):
            if name.startswith('Test') and inspect.isclass(objet):
                if getattr(object, 'disabled', False):
                    continue
                minimum = None
                for _, method in inspect.getmembers(objet, inspect.ismethod):
                    if PYTHON3:
                        number = method.__func__.__code__.co_firstlineno
                    else:
                        number = method.im_func.func_code.co_firstlineno
                    if minimum is None or number < minimum:
                        minimum = number
                if minimum is not None:
                    collection.append((minimum, name, objet, False))
            elif name.startswith('test_') and inspect.isfunction(objet):
                if PYTHON3:
                    code = objet.__code__
                else:
                    code = objet.func_code
                collection.append((code.co_firstlineno, name, objet,
                                   bool(code.co_flags & 32)))
        if not collection:
            return
        self.delayed_setup_module = None
        self.did_tests_in_module = False
        if hasattr(module, 'setup_module'):
            self.delayed_setup_module = module.setup_module, module
        for _, name, objet, generator in sorted(collection):
            self.delayed_setup_class = None
            self.did_tests_in_class = False
            if inspect.isclass(objet):
                if not getattr(object, 'disabled', False):
                    self.handle_class(prefix + '/' + name, objet)
            else:
                self.handle_function(prefix + '/' + name, objet,
                                      generator, None)
        if self.did_tests_in_module and hasattr(module, 'teardown_module'):
            module.teardown_module(module)

    def handle_class(self, prefix, classe):
        collection = []
        for name, method in inspect.getmembers(classe, inspect.ismethod):
            if name.startswith('test_'):
                if PYTHON3:
                    code = method.__func__.__code__
                else:
                    code = method.im_func.func_code
                collection.append((code.co_firstlineno, name, method,
                                   bool(code.co_flags & 32)))
        if not collection:
            return
        instance = classe()
        if hasattr(instance, 'setup_class'):
            self.delayed_setup_module = module.setup_class, classe
        for _, name, method, generator in sorted(collection):
            self.handle_function(prefix + '/' + name, getattr(instance, name),
                                  generator, instance)
        if self.did_tests_in_class and hasattr(instance, 'teardown_class'):
            instance.teardown_class(classe)

    def handle_function(self, prefix, function, generator, instance):
        """Launch the test FUNCTION, or all tests it generates.

        If GENERATOR is false, FUNCTION is launched as the test PREFIX,
        without arguments.  Otherwise, FUNCTION is called and each item it
        yields should be a sequence holding a test function, followed by
        the arguments for that test; such tests are named after PREFIX, a
        slash and their rank, counted from 1.  INSTANCE is the test class
        instance owning the test, or None, and is passed on unchanged.
        """
        if not generator:
            self.launch_test(prefix, function, (), instance)
            return
        for counter, arguments in enumerate(function()):
            # Formatting needs no knowledge of the Python version, and does
            # not try to decode PREFIX as concatenating with unicode does.
            self.launch_test('%s/%d' % (prefix, counter + 1),
                             arguments[0], arguments[1:], instance)

    def launch_test(self, prefix, function, arguments, instance):
        """Execute one test, unless filtered out, and record its outcome.

        PREFIX is the name of the test, FUNCTION is called with ARGUMENTS
        to execute it, and INSTANCE is the test class instance owning the
        test, or None.  A test rejected by the -x, -k or -o options is
        only counted as skipped.  Otherwise, pending module and class
        setups are run, then 'setup_method', the test itself (with stdout
        and stderr captured unless -n) and 'teardown_method'.  A failure
        is appended to self.failures: the ordinal alone when saving (-s),
        else a tuple also holding the captured output and the traceback.
        Limit_Reached is raised when a failure brings the count of
        failures up to self.limit.
        """
        # Check if this test should be retained.
        if (self.exclusion is not None and self.exclusion.search(prefix)
              or self.pattern is not None and not self.pattern.search(prefix)
              or self.ordinals and self.total_count+1 not in self.ordinals):
            self.mark_progression(prefix, None)
            return
        # This test should definitely be executed.
        if self.delayed_setup_module is not None:
            self.delayed_setup_module[0](self.delayed_setup_module[1])
            self.delayed_setup_module = None
        if self.delayed_setup_class is not None:
            self.delayed_setup_class[0](self.delayed_setup_class[1])
            self.delayed_setup_class = None
        if instance is not None and hasattr(instance, 'setup_method'):
            instance.setup_method(function)
        if self.capture:
            # Python 2 uses the StringIO module, which accepts both str
            # and unicode; that module does not exist in Python 3.
            if sys.version_info[0] >= 3:
                from io import StringIO
            else:
                from StringIO import StringIO
            saved_stdout = sys.stdout
            saved_stderr = sys.stderr
            stdout = sys.stdout = StringIO()
            stderr = sys.stderr = StringIO()
        self.activate_profiling()
        try:
            try:
                function(*arguments)
            except KeyboardInterrupt:
                exception = sys.exc_info()
                raise
            except:
                # On purpose: whatever a test raises, SystemExit included,
                # makes that test fail without stopping the run.
                exception = sys.exc_info()
            else:
                exception = None
        finally:
            self.deactivate_profiling()
            if self.capture:
                sys.stdout = saved_stdout
                sys.stderr = saved_stderr
                stdout = stdout.getvalue()
                stderr = stderr.getvalue()
            else:
                stdout = None
                stderr = None
            if exception is None:
                self.mark_progression(prefix, True)
            else:
                self.mark_progression(prefix, False)
                # mark_progression() just counted this test, so the total
                # count is now the ordinal of this test.
                if self.save:
                    self.failures.append(self.total_count)
                else:
                    tracing = ''.join(traceback.format_exception(*exception))
                    self.failures.append(
                        (self.total_count, prefix, function, arguments,
                         stdout, stderr, tracing))
            if instance is not None and hasattr(instance, 'teardown_method'):
                instance.teardown_method(function)
            self.did_tests_in_class = True
            self.did_tests_in_module = True
        # Failures are also recorded for files which cannot be imported, so
        # the count may jump over the limit: an equality test would then
        # never trigger.  As before, a limit of zero never stops the run.
        if exception is not None and 0 < self.limit <= len(self.failures):
            raise Limit_Reached

    def mark_progression(self, prefix, succes):
        self.total_count += 1
        if succes is None:
            self.skip_count += 1
        else:
            write = sys.stderr.write
            if self.verbose:
                write('%5d. [%s] %s\n' % (self.total_count, prefix,
                                           ('FAILED', 'ok')[succes]))
            else:
                if self.column == WIDTH:
                    write('\n')
                    self.column = 0
                if not self.column:
                    if self.counter:
                        text = '%5d ' % (self.counter + 1)
                    else:
                        text = self.identifier + ' '
                    write(text)
                    self.column = len(text)
                if PYTHON3:
                    write('E·'[succes])
                else:
                    write(u'E·'[succes])
                self.column += 1
                self.counter += 1

    def activate_profiling(self):
        if self.profiler is not None:
            self.profiler.enable(subcalls=True, builtins=True)

    def deactivate_profiling(self):
        if self.profiler is not None:
            self.profiler.disable()

class Friendly_StreamWriter:
    """Stream wrapper accepting both text and UTF-8 encoded strings.

    This avoids some Unicode nightmares, given our sources are all UTF-8.
    Whatever gets written is encoded according to the preferred encoding
    of the locale (UTF-8 if the locale module is not available), and
    characters which cannot be encoded are replaced by backslash escapes.
    """

    def __init__(self, stream):
        """Wrap STREAM, a file-like object to receive the encoded output."""
        if locale is None:
            encoding = 'UTF-8'
        else:
            encoding = locale.getpreferredencoding()
        import codecs
        # The codec writer produces encoded bytes.  Text streams, like the
        # standard streams of Python 3, refuse bytes but give access to
        # their binary side through a "buffer" attribute.
        target = getattr(stream, 'buffer', None)
        if target is None:
            target = stream
        else:
            # Keep what was already written ahead of what comes next.
            stream.flush()
        writer = codecs.getwriter(encoding)
        self.stream = writer(target, 'backslashreplace')

    def __getattr__(self, name):
        # Only called for attributes not found the usual way: flush() and
        # other stream attributes are delegated to the wrapped stream, so
        # an instance may really stand for sys.stdout or sys.stderr.
        if name == 'stream':
            # Not yet initialized: avoid an endless recursion.
            raise AttributeError(name)
        return getattr(self.stream, name)

    def write(self, text):
        """Write TEXT, decoding it from UTF-8 first if it holds bytes."""
        # In Python 2, bytes is merely another name for str.
        if isinstance(text, (bytes, bytearray)):
            text = text.decode('UTF-8')
        self.stream.write(text)

    def writelines(self, lines):
        """Write each item of LINES in turn, as write() would do."""
        for line in lines:
            self.write(line)

# Runner instance of this module, and its main() method made available as
# a plain function.
run = Main()
main = run.main

class ExceptionExpected(Exception):
    """Raised by raises() when the expected exception did not happen."""

def raises(expected, *args, **kws):
    """Check that some code raises the EXPECTED exception.

    EXPECTED is an exception class, or a tuple of such classes.  The first
    of ARGS is either a callable, then called with the remaining ARGS and
    the KWS keywords, or else, when no keywords are given, a string
    holding an expression, then evaluated with the names of the caller.
    Return None if EXPECTED got raised and raise ExceptionExpected if
    nothing was raised; any other exception is left to propagate.  Raise
    TypeError if there is nothing to call nor evaluate.
    """
    if not args:
        # Checked outside the guarded code below, where the IndexError of
        # args[0] could be mistaken for the expected exception.
        raise TypeError("raises() needs a callable or an expression string.")
    try:
        # Python 2: both str and unicode hold expressions.
        string_types = basestring
    except NameError:
        string_types = str
    target = args[0]
    try:
        if isinstance(target, string_types) and not kws:
            # The expression is written in terms of the names visible
            # where raises() is called.  Names of this module stay
            # available as a fallback, as they previously were.
            caller = sys._getframe(1)
            namespace = dict(globals())
            namespace.update(caller.f_globals)
            # NOTE: eval() executes arbitrary code; expressions should
            # only come from the test suite itself.
            eval(target, namespace, dict(caller.f_locals))
        else:
            target(*args[1:], **kws)
    except expected:
        return
    raise ExceptionExpected("Exception did not happen.")

if __name__ == '__main__':
    # Run as a script: pass on the command-line arguments, less the
    # program name.
    main(*sys.argv[1:])
